{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms, models\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据预处理\n",
    "transform = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(224),# 对图像进行随机的crop以后再resize成固定大小\n",
    "    transforms.RandomRotation(20), # 随机旋转角度\n",
    "    transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转\n",
    "    transforms.ToTensor() \n",
    "])\n",
    " \n",
    "# 读取数据\n",
    "root = 'image'\n",
    "train_dataset = datasets.ImageFolder(root + '/train', transform)\n",
    "test_dataset = datasets.ImageFolder(root + '/test', transform)\n",
    " \n",
    "# 导入数据\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['cat', 'dog']\n",
      "{'cat': 0, 'dog': 1}\n"
     ]
    }
   ],
   "source": [
    "classes = train_dataset.classes\n",
    "classes_index = train_dataset.class_to_idx\n",
    "print(classes)\n",
    "print(classes_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace=True)\n",
      "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (6): ReLU(inplace=True)\n",
      "    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): ReLU(inplace=True)\n",
      "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace=True)\n",
      "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (13): ReLU(inplace=True)\n",
      "    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (15): ReLU(inplace=True)\n",
      "    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (18): ReLU(inplace=True)\n",
      "    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (20): ReLU(inplace=True)\n",
      "    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (22): ReLU(inplace=True)\n",
      "    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (25): ReLU(inplace=True)\n",
      "    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (27): ReLU(inplace=True)\n",
      "    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (29): ReLU(inplace=True)\n",
      "    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Dropout(p=0.5, inplace=False)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU(inplace=True)\n",
      "    (5): Dropout(p=0.5, inplace=False)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model = models.vgg16(pretrained = True)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 如果我们想只训练模型的全连接层\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False\n",
    "    \n",
    "# 构建新的全连接层\n",
    "model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),\n",
    "                                       torch.nn.ReLU(),\n",
    "                                       torch.nn.Dropout(p=0.5),\n",
    "                                       torch.nn.Linear(100, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "LR = 0.0003\n",
    "# 定义代价函数\n",
    "entropy_loss = nn.CrossEntropyLoss()\n",
    "# 定义优化器\n",
    "optimizer = optim.Adam(model.parameters(), LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    for i, data in enumerate(train_loader):\n",
    "        # 获得数据和对应的标签\n",
    "        inputs, labels = data\n",
    "        # 获得模型预测结果，（64，10）\n",
    "        out = model(inputs)\n",
    "        # 交叉熵代价函数out(batch,C),labels(batch)\n",
    "        loss = entropy_loss(out, labels)\n",
    "        # 梯度清0\n",
    "        optimizer.zero_grad()\n",
    "        # 计算梯度\n",
    "        loss.backward()\n",
    "        # 修改权值\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def test():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    for i, data in enumerate(test_loader):\n",
    "        # 获得数据和对应的标签\n",
    "        inputs, labels = data\n",
    "        # 获得模型预测结果\n",
    "        out = model(inputs)\n",
    "        # 获得最大值，以及最大值所在的位置\n",
    "        _, predicted = torch.max(out, 1)\n",
    "        # 预测正确的数量\n",
    "        correct += (predicted == labels).sum()\n",
    "    print(\"Test acc: {0}\".format(correct.item() / len(test_dataset)))\n",
    "    \n",
    "    correct = 0\n",
    "    for i, data in enumerate(train_loader):\n",
    "        # 获得数据和对应的标签\n",
    "        inputs, labels = data\n",
    "        # 获得模型预测结果\n",
    "        out = model(inputs)\n",
    "        # 获得最大值，以及最大值所在的位置\n",
    "        _, predicted = torch.max(out, 1)\n",
    "        # 预测正确的数量\n",
    "        correct += (predicted == labels).sum()\n",
    "    print(\"Train acc: {0}\".format(correct.item() / len(train_dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n",
      "Test acc: 0.8\n",
      "Train acc: 0.8325\n",
      "epoch: 1\n",
      "Test acc: 0.88\n",
      "Train acc: 0.88\n",
      "epoch: 2\n",
      "Test acc: 0.835\n",
      "Train acc: 0.87\n",
      "epoch: 3\n",
      "Test acc: 0.855\n",
      "Train acc: 0.8825\n",
      "epoch: 4\n",
      "Test acc: 0.865\n",
      "Train acc: 0.8575\n",
      "epoch: 5\n",
      "Test acc: 0.89\n",
      "Train acc: 0.9075\n",
      "epoch: 6\n",
      "Test acc: 0.88\n",
      "Train acc: 0.905\n",
      "epoch: 7\n",
      "Test acc: 0.865\n",
      "Train acc: 0.9175\n",
      "epoch: 8\n",
      "Test acc: 0.88\n",
      "Train acc: 0.8875\n",
      "epoch: 9\n",
      "Test acc: 0.895\n",
      "Train acc: 0.905\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(0, 10):\n",
    "    print('epoch:',epoch)\n",
    "    train()\n",
    "    test()\n",
    "    \n",
    "torch.save(model.state_dict(), 'cat_dog_fc.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
